#!/usr/bin/env python
# -*- coding:utf-8 -*-
from __future__ import nested_scopes, generators, division, absolute_import, with_statement, print_function, \
    unicode_literals

import os
import yaml
from flask import Flask, request, json
from flask_cors import cross_origin

from util.model_factory import ModelFactory, Model
from const import DEFAULT_CONFIG, STATIC_PATH, MODEL_PATH

config = os.path.join(STATIC_PATH, 'conf.yaml')

# Parse the configuration file. Top-level keys from conf.yaml override the
# defaults; dict.update is a shallow merge, so a nested section present in
# conf.yaml replaces the whole default section.
# NOTE(review): this mutates the imported DEFAULT_CONFIG object in place, so
# anything else reading const.DEFAULT_CONFIG also sees the merged values.
# Kept as-is in case other modules rely on that.
c = DEFAULT_CONFIG
# Binary mode lets PyYAML detect the encoding itself (UTF-8 / UTF-16 via BOM)
# instead of decoding with the platform's locale encoding, which breaks
# non-ASCII configs on non-UTF-8 locales.
with open(config, 'rb') as fp:
    # safe_load returns None for an empty document; treat that as "no overrides".
    c.update(yaml.safe_load(fp) or {})


class LSTMModel(Model):
    """Model subclass for the LSTM classifier.

    Its file locations and input length are attached as class attributes
    right after the class definition, from the parsed configuration ``c``.
    """


# Names under fit.save decide which files in MODEL_PATH hold the model
# definition (.yaml), its weights (.h5) and the tokenizer (.pkl).
_save_cfg = c.get('fit').get('save')
LSTMModel.model_path = os.path.join(MODEL_PATH, '{}.yaml'.format(_save_cfg.get('model')))
LSTMModel.weights_path = os.path.join(MODEL_PATH, '{}.h5'.format(_save_cfg.get('model')))
LSTMModel.tokenize_path = os.path.join(MODEL_PATH, '{}.pkl'.format(_save_cfg.get('tokenize')))
LSTMModel.max_len = c.get('word2vec').get('max_len')
del _save_cfg

# Flask application object; the API routes below are registered on it.
app = Flask(__name__)


@app.route('/api/labels', methods=['POST'])
@cross_origin()  # allow cross-origin requests
def get_labels():
    """Classify a batch of texts.

    The request body must be a JSON array of objects, each holding the text to
    classify under the ``data`` key. The response is the same array with
    ``data`` removed from every object and a string ``label`` added. A raw
    model label of 2 is reported as ``"-1"``.

    Returns:
        ``(json_text, 200)`` on success, or ``('Unknown parameters', 400)``
        when the body cannot be read or does not have the expected shape.
    """
    try:
        data = request.get_json()
    except Exception:
        return 'Unknown parameters', 400
    # get_json() can return None (e.g. a non-JSON Content-Type on older Flask),
    # and a body of the wrong shape used to crash the loop below with a 500.
    # Validate everything before mutating anything.
    if not isinstance(data, list) or not all(
            isinstance(item, dict) and 'data' in item for item in data):
        return 'Unknown parameters', 400
    if not data:
        # Nothing to classify; don't call the model with an empty batch.
        return json.dumps(data), 200
    # Move each text out of its object; the remaining keys are echoed back.
    text = [item.pop('data') for item in data]
    model = ModelFactory.get_model(LSTMModel)
    labels = model.predict(text)
    for i, item in enumerate(data):
        # NOTE(review): raw label 2 is reported as -1 — presumably a class-index
        # convention of the model; confirm against the training setup.
        item['label'] = str(labels[i] if labels[i] != 2 else -1)
    return json.dumps(data), 200


@app.route('/api/label', methods=['POST'])
@cross_origin()  # allow cross-origin requests
def get_label():
    """Classify a single text.

    The request body must be a JSON object holding the text to classify under
    the ``data`` key. The response is that object with ``data`` removed and a
    string ``label`` added. A raw model label of 2 is reported as ``"-1"``.

    Returns:
        ``(json_text, 200)`` on success, or ``('Unknown parameters', 400)``
        when the body cannot be read or does not have the expected shape.
    """
    try:
        data = request.get_json()
    except Exception:
        return 'Unknown parameters', 400
    # get_json() can return None (e.g. a non-JSON Content-Type on older Flask);
    # that, a non-object body, or a missing 'data' key previously caused a 500.
    if not isinstance(data, dict) or 'data' not in data:
        return 'Unknown parameters', 400
    # The model is called with a one-element batch.
    text = [data.pop('data')]
    model = ModelFactory.get_model(LSTMModel)
    label = model.predict(text)[0]
    # NOTE(review): raw label 2 is reported as -1 — presumably a class-index
    # convention of the model; confirm against the training setup.
    data['label'] = str(label if label != 2 else -1)
    return json.dumps(data), 200


@app.route('/api/update')
def update():
    """Ask ModelFactory for the LSTM model with ``update=True``.

    Presumably this rebuilds the cached model from the files on disk — confirm
    against ModelFactory.get_model. Always answers with a fixed JSON message.

    NOTE(review): this is a GET route with a side effect and no CORS decorator,
    unlike the classification endpoints.
    """
    ModelFactory.get_model(LSTMModel, update=True)
    payload = {'data': 'Successfully updated model'}
    body = json.dumps(payload)
    return body, 200


if __name__ == '__main__':
    # Start Flask's built-in server bound to localhost:5000 when run as a script.
    app.run(host='localhost', port=5000)
